{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04183.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04183.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gf2wSpU0cEOz",
        "colab_type": "text"
      },
      "source": [
        "## 5.7 使用重复元素的网络 ( VGG )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IfFnKfQuqA5i",
        "colab_type": "text"
      },
      "source": [
        "### 5.7.1 VGG 块"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rXEyfCGaqGml",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn, optim \n",
        "import d2l \n",
        "import time \n",
        "# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.\n",
        "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "\n",
        "\n",
        "def vgg_block(num_convs, in_channels, out_channels):\n",
        "    \"\"\"Assemble one VGG block as an ``nn.Sequential``.\n",
        "\n",
        "    The block is ``num_convs`` pairs of (3x3 convolution with padding 1, ReLU)\n",
        "    followed by a single 2x2 max-pooling layer with stride 2.\n",
        "\n",
        "    Args:\n",
        "        num_convs: number of convolution + ReLU pairs in the block.\n",
        "        in_channels: input channels of the first convolution.\n",
        "        out_channels: output channels of every convolution; also the input\n",
        "            channels of each convolution after the first.\n",
        "\n",
        "    Returns:\n",
        "        nn.Sequential holding the layers in the order described above.\n",
        "    \"\"\"\n",
        "    layers = []\n",
        "    channels = in_channels  # only the first convolution sees in_channels\n",
        "    for _ in range(num_convs):\n",
        "        layers += [\n",
        "            nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),\n",
        "            nn.ReLU(),\n",
        "        ]\n",
        "        channels = out_channels\n",
        "    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
        "    return nn.Sequential(*layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WEABsZFJsBXo",
        "colab_type": "text"
      },
      "source": [
        "### 5.7.2 VGG 网络"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ixCKsYGfsPcQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# One (num_convs, in_channels, out_channels) entry per VGG block, in forward order.\n",
        "conv_arch = (\n",
        "    (1, 1, 64),\n",
        "    (1, 64, 128),\n",
        "    (2, 128, 256),\n",
        "    (2, 256, 512),\n",
        "    (2, 512, 512),\n",
        ")\n",
        "# Width of the two hidden fully connected layers.\n",
        "fc_hidden_units = 4096\n",
        "# Input size of the first fully connected layer: 512 channels * 7 * 7 positions.\n",
        "fc_features = 512 * 7 * 7"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KzXIHUMncQee",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def vgg(conv_arch, fc_features, fc_hidden_units=4096, num_classes=10):\n",
        "    \"\"\"Build a VGG-style network: stacked VGG blocks plus a fully connected head.\n",
        "\n",
        "    Args:\n",
        "        conv_arch: iterable of (num_convs, in_channels, out_channels) tuples,\n",
        "            one per VGG block, in forward order.\n",
        "        fc_features: input size of the first linear layer; it has to match the\n",
        "            flattened size of the last block's output for the intended input.\n",
        "        fc_hidden_units: width of the two hidden linear layers.\n",
        "        num_classes: output size of the final linear layer. Defaults to 10.\n",
        "\n",
        "    Returns:\n",
        "        nn.Sequential whose children are named 'vgg_block_1' ... 'vgg_block_N'\n",
        "        followed by 'fc'.\n",
        "    \"\"\"\n",
        "    net = nn.Sequential()\n",
        "    # Convolutional part: blocks are numbered from 1 in their module names.\n",
        "    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch, start=1):\n",
        "        net.add_module('vgg_block_' + str(i), vgg_block(num_convs, in_channels, out_channels))\n",
        "    # Classifier head. d2l.FlattenLayer presumably reshapes to (batch, -1) -- TODO confirm in d2l.\n",
        "    net.add_module('fc', nn.Sequential(\n",
        "        d2l.FlattenLayer(),\n",
        "        nn.Linear(fc_features, fc_hidden_units),\n",
        "        nn.ReLU(),\n",
        "        nn.Dropout(0.5),\n",
        "        nn.Linear(fc_hidden_units, fc_hidden_units),\n",
        "        nn.ReLU(),\n",
        "        nn.Dropout(0.5),\n",
        "        nn.Linear(fc_hidden_units, num_classes),\n",
        "    ))\n",
        "    return net"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MPqgb5a5du55",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "26722463-31b9-4254-93b2-bb6e4ede1080"
      },
      "source": [
        "# Build the full-size network and feed one random input of shape (1, 1, 224, 224)\n",
        "# through its top-level children one at a time, printing each output shape.\n",
        "net = vgg(conv_arch, fc_features, fc_hidden_units)\n",
        "X = torch.rand(1, 1, 224, 224)\n",
        "\n",
        "for block_name, block in net.named_children():\n",
        "    X = block(X)\n",
        "    print(block_name, 'output shape: ', X.shape)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "vgg_block_1 output shape:  torch.Size([1, 64, 112, 112])\n",
            "vgg_block_2 output shape:  torch.Size([1, 128, 56, 56])\n",
            "vgg_block_3 output shape:  torch.Size([1, 256, 28, 28])\n",
            "vgg_block_4 output shape:  torch.Size([1, 512, 14, 14])\n",
            "vgg_block_5 output shape:  torch.Size([1, 512, 7, 7])\n",
            "fc output shape:  torch.Size([1, 10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePBRpXDfeIf5",
        "colab_type": "text"
      },
      "source": [
        "### 5.7.3 获取数据和训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OMoczg1Yexkw",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 739
        },
        "outputId": "d5c3aaca-3852-4667-c606-6d4e6f4d0b01"
      },
      "source": [
        "# Divide every channel count and both fully connected sizes by `ratio`.\n",
        "# The first block's input stays at its conv_arch value (1 channel), unscaled.\n",
        "ratio = 8\n",
        "small_conv_arch = [\n",
        "    (num_convs, in_channels if idx == 0 else in_channels // ratio, out_channels // ratio)\n",
        "    for idx, (num_convs, in_channels, out_channels) in enumerate(conv_arch)\n",
        "]\n",
        "net = vgg(small_conv_arch, fc_features // ratio, fc_hidden_units // ratio)\n",
        "net"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (vgg_block_1): Sequential(\n",
              "    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (vgg_block_2): Sequential(\n",
              "    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (vgg_block_3): Sequential(\n",
              "    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (vgg_block_4): Sequential(\n",
              "    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (vgg_block_5): Sequential(\n",
              "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU()\n",
              "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (3): ReLU()\n",
              "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (fc): Sequential(\n",
              "    (0): FlattenLayer()\n",
              "    (1): Linear(in_features=3136, out_features=512, bias=True)\n",
              "    (2): ReLU()\n",
              "    (3): Dropout(p=0.5, inplace=False)\n",
              "    (4): Linear(in_features=512, out_features=512, bias=True)\n",
              "    (5): ReLU()\n",
              "    (6): Dropout(p=0.5, inplace=False)\n",
              "    (7): Linear(in_features=512, out_features=10, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CIsKs16PgaOD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "3e10ca03-3ad3-4d54-f54f-35daf71f7fee"
      },
      "source": [
        "# Training hyperparameters.\n",
        "batch_size = 64\n",
        "lr = 0.001\n",
        "num_epochs = 5\n",
        "\n",
        "# Fashion-MNIST iterators from the d2l helper, requested with resize=224.\n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
        "\n",
        "# `optim` is torch.optim, imported at the top of the notebook.\n",
        "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
        "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.4490, train acc 0.836, test acc 0.870, time 108.5 sec\n",
            "epoch 2, loss 0.3146, train acc 0.885, test acc 0.894, time 108.8 sec\n",
            "epoch 3, loss 0.2711, train acc 0.901, test acc 0.905, time 108.5 sec\n",
            "epoch 4, loss 0.2421, train acc 0.911, test acc 0.915, time 108.2 sec\n",
            "epoch 5, loss 0.2187, train acc 0.921, test acc 0.915, time 108.3 sec\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}